{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os.path as op\n",
    "import os\n",
    "import shutil\n",
    "from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting /home/ogrisel/data/mnist/train-images-idx3-ubyte.gz\n",
      "Extracting /home/ogrisel/data/mnist/train-labels-idx1-ubyte.gz\n",
      "Extracting /home/ogrisel/data/mnist/t10k-images-idx3-ubyte.gz\n",
      "Extracting /home/ogrisel/data/mnist/t10k-labels-idx1-ubyte.gz\n"
     ]
    }
   ],
   "source": [
    "# Directory holding the MNIST archives; '~' is expanded to the home directory.\n",
    "data_dir = op.expanduser(\"~/data/mnist\")\n",
    "# Load the MNIST splits; one_hot=True requests one-hot encoded label rows.\n",
    "mnist = read_data_sets(data_dir, one_hot=True)\n",
    "# Root directory for the summary event files written by the SummaryWriter\n",
    "# objects created later in the notebook.\n",
    "logs_dir = '/tmp/tensorflow_logs'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def vec_normalize(vec):\n",
    "    \"\"\"Scale `vec` by the inverse of its global L2 norm.\n",
    "\n",
    "    The norm is taken over every element of `vec` at once (no axis is\n",
    "    reduced separately). A small constant is added to the denominator so\n",
    "    that an all-zero input does not cause a division by zero.\n",
    "    \"\"\"\n",
    "    squared_sum = tf.reduce_sum(tf.square(vec))\n",
    "    denominator = tf.sqrt(squared_sum) + 1e-7\n",
    "    return vec / denominator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Reset the default graph before building, so that re-running this cell does\n",
    "# not add a second copy of every op to the previously built graph.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.Session()\n",
    "dtype = tf.float32\n",
    "# The learning rate lives in a graph variable so that the lr_up / lr_down ops\n",
    "# defined at the end of this cell can rescale it in place.\n",
    "learning_rate = tf.Variable(tf.constant(0.001, dtype=dtype))\n",
    "\n",
    "\n",
    "# Placeholders filled by data_dict(): 784 input features and 10 target columns\n",
    "# per row. The leading None leaves the number of rows (batch size) open.\n",
    "with tf.name_scope('input'):\n",
    "    x = tf.placeholder(dtype=dtype, shape=[None, 784], name='x-input')\n",
    "    y = tf.placeholder(dtype=dtype, shape=[None, 10], name='y-input')\n",
    "\n",
    "\n",
    "with tf.name_scope('variables'):\n",
    "    # Parameters of the linear model: a (784, 10) weight matrix drawn from a\n",
    "    # truncated normal with stddev 0.1, and a zero-initialised bias vector.\n",
    "    W = tf.Variable(tf.truncated_normal(shape=(784, 10), stddev=0.1,\n",
    "                                        dtype=dtype),\n",
    "                    name='W')\n",
    "    tf.histogram_summary('weights', W)\n",
    "    b = tf.Variable(tf.zeros(shape=(10,), dtype=dtype), name='b')\n",
    "    tf.histogram_summary('biases', b)\n",
    "    # Two flat vectors with as many entries as W (784 * 10). The 'updates'\n",
    "    # scope below blends the normalised gradient of W into each of them, using\n",
    "    # a small blend factor for the slow one and a larger one for the fast one.\n",
    "    slow_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))\n",
    "    fast_direction = tf.Variable(tf.zeros(shape=[784 * 10], dtype=dtype))\n",
    "    # (1, N) x (N, 1) matrix product, i.e. the dot product of the two direction\n",
    "    # vectors; [0, 0] extracts the single entry as a scalar tensor.\n",
    "    dir_similarity = tf.matmul(tf.reshape(slow_direction, [1, -1]),\n",
    "                               tf.reshape(fast_direction, [-1, 1]))[0, 0]\n",
    "    tf.scalar_summary('dir_similarity', dir_similarity)\n",
    "\n",
    "\n",
    "with tf.name_scope('model'):\n",
    "    # Affine map of the inputs followed by a softmax over the 10 outputs.\n",
    "    preactivations = tf.matmul(x, W) + b\n",
    "    tf.histogram_summary('preactivations', preactivations)\n",
    "    y_pred = tf.nn.softmax(preactivations)\n",
    "    tf.histogram_summary('predicted_probabilities', y_pred)\n",
    "\n",
    "\n",
    "with tf.name_scope('loss'):\n",
    "    # The loss op is fed the pre-softmax values rather than y_pred; it yields\n",
    "    # one cross-entropy value per input row, which is then averaged.\n",
    "    cross_entropies = tf.nn.softmax_cross_entropy_with_logits(preactivations, y)\n",
    "    cross_entropy = tf.reduce_mean(cross_entropies, name='cross_entropy')\n",
    "    tf.scalar_summary('cross_entropy', cross_entropy)\n",
    "\n",
    "\n",
    "with tf.name_scope('accuracy'):\n",
    "    # A row counts as correct when the arg-max of the predicted probabilities\n",
    "    # equals the arg-max of the target row.\n",
    "    with tf.name_scope('correct_prediction'):\n",
    "        correct_prediction = tf.equal(tf.argmax(y, 1),\n",
    "                                      tf.argmax(y_pred, 1))\n",
    "    # The averaging ops get their own scope name so they are not grouped under\n",
    "    # a second, auto-renamed 'correct_prediction' node in the graph view.\n",
    "    with tf.name_scope('accuracy'):\n",
    "        # Cast the boolean matches to floats; their mean is the accuracy.\n",
    "        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))\n",
    "    tf.scalar_summary('accuracy', accuracy)\n",
    "\n",
    "\n",
    "with tf.name_scope('gradient_directions'):\n",
    "    # Gradients of the batch-mean cross entropy with respect to W and b.\n",
    "    [gW, gb] = tf.gradients(cross_entropy, [W, b])\n",
    "    # L2 norm of the W gradient alone (used to normalise it below) ...\n",
    "    gW_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)))\n",
    "    # ... and L2 norm of the full gradient (W and b together), for monitoring.\n",
    "    g_norm = tf.sqrt(tf.reduce_sum(tf.square(gW)) + tf.reduce_sum(tf.square(gb)))\n",
    "    tf.scalar_summary('gradient norm', g_norm)\n",
    "    # Unit-length W gradient flattened to one vector, so it has the same layout\n",
    "    # as slow_direction / fast_direction. Same formula as vec_normalize().\n",
    "    gW_normed = tf.reshape(gW / (gW_norm + 1e-7), [-1])\n",
    "    \n",
    "\n",
    "with tf.name_scope('updates'):\n",
    "    # Plain gradient-descent step on W and b using the current learning rate.\n",
    "    W_update = W.assign_add(-learning_rate * gW)\n",
    "    b_update = b.assign_add(-learning_rate * gb)\n",
    "\n",
    "    # Slow direction: blend 5% of the current normalised gradient into the\n",
    "    # stored vector, then rescale the result to unit length.\n",
    "    slow_rate = 0.05\n",
    "    new_slow_dir = slow_rate * gW_normed + (1 - slow_rate) * slow_direction\n",
    "    slow_dir_update = slow_direction.assign(vec_normalize(new_slow_dir))\n",
    "    # Reset op: overwrite the stored vector with the current normalised gradient.\n",
    "    slow_dir_reset = slow_direction.assign(gW_normed)\n",
    "\n",
    "    # Fast direction: same scheme with a 50% blend, so it follows the most\n",
    "    # recent gradients much more closely.\n",
    "    fast_rate = 0.5\n",
    "    new_fast_dir = fast_rate * gW_normed + (1 - fast_rate) * fast_direction\n",
    "    fast_dir_update = fast_direction.assign(vec_normalize(new_fast_dir))\n",
    "    fast_dir_reset = fast_direction.assign(gW_normed)\n",
    "    \n",
    "    # Ops that double the learning rate or divide it by ten. Together with the\n",
    "    # two reset ops above they are only referenced by the commented-out\n",
    "    # schedule in the training-loop cell.\n",
    "    lr_up = learning_rate.assign(2 * learning_rate)\n",
    "    lr_down = learning_rate.assign(0.1 * learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def data_dict(train=True, batch_size=128):\n",
    "    \"\"\"Build the feed_dict binding one batch of MNIST data to `x` and `y`.\n",
    "\n",
    "    When `train` is true the next `batch_size` training examples are drawn;\n",
    "    otherwise the whole test split is used and `batch_size` is ignored.\n",
    "    Both arrays are cast to float32 before being returned.\n",
    "    \"\"\"\n",
    "    if not train:\n",
    "        images, labels = mnist.test.images, mnist.test.labels\n",
    "    else:\n",
    "        images, labels = mnist.train.next_batch(batch_size)\n",
    "    return {x: images.astype(np.float32), y: labels.astype(np.float32)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[3.7324717,\n",
       " 1.8400075,\n",
       " 3.6116686,\n",
       " 1.7056865,\n",
       " 2.0967669,\n",
       " 2.3198204,\n",
       " 2.0171783,\n",
       " 1.7580715,\n",
       " 2.3690248,\n",
       " 1.6259246,\n",
       " 3.3400445,\n",
       " 2.0483375,\n",
       " 1.8097279,\n",
       " 2.1330061,\n",
       " 1.8711376,\n",
       " 2.6461897,\n",
       " 3.4301701,\n",
       " 1.9341406,\n",
       " 1.4128025,\n",
       " 3.1168551,\n",
       " 3.7157907,\n",
       " 2.9077392,\n",
       " 3.325181,\n",
       " 0.89536273,\n",
       " 4.3979654,\n",
       " 0.92203724,\n",
       " 2.2007186,\n",
       " 2.2937737,\n",
       " 2.1817045,\n",
       " 1.8752966,\n",
       " 1.6373912,\n",
       " 2.0365462,\n",
       " 1.6608343,\n",
       " 2.6484132,\n",
       " 3.4957781,\n",
       " 1.9901035,\n",
       " 1.9084624,\n",
       " 2.6680474,\n",
       " 2.0449464,\n",
       " 3.0129635,\n",
       " 2.3355575,\n",
       " 2.6466174,\n",
       " 2.5199924,\n",
       " 1.9306296,\n",
       " 2.47153,\n",
       " 2.5479219,\n",
       " 2.2662649,\n",
       " 2.6136937,\n",
       " 1.7118273,\n",
       " 1.3672709,\n",
       " 3.4149623,\n",
       " 2.152194,\n",
       " 2.3145103,\n",
       " 2.0982614,\n",
       " 1.5726066,\n",
       " 5.6496277,\n",
       " 1.0098101,\n",
       " 3.9320879,\n",
       " 1.1875807,\n",
       " 1.6345842,\n",
       " 4.9190865,\n",
       " 2.1639643,\n",
       " 3.1059659,\n",
       " 2.1172271,\n",
       " 1.7091054,\n",
       " 1.7676189,\n",
       " 2.1179478,\n",
       " 3.7244837,\n",
       " 2.3459547,\n",
       " 4.1443148,\n",
       " 2.1852176,\n",
       " 3.9915304,\n",
       " 1.9832186,\n",
       " 2.7600195,\n",
       " 1.882431,\n",
       " 3.7910852,\n",
       " 0.71117669,\n",
       " 2.0551419,\n",
       " 4.1546483,\n",
       " 3.7213643,\n",
       " 1.8376471,\n",
       " 2.0418112,\n",
       " 1.8615329,\n",
       " 3.6794746,\n",
       " 1.9143579,\n",
       " 2.3036268,\n",
       " 2.2760646,\n",
       " 2.079437,\n",
       " 1.5919703,\n",
       " 2.2989087,\n",
       " 1.7845988,\n",
       " 3.588326,\n",
       " 3.0593004,\n",
       " 4.4371295,\n",
       " 4.3214245,\n",
       " 3.2552154,\n",
       " 1.5259778,\n",
       " 1.9014853,\n",
       " 2.7919619,\n",
       " 2.8022759,\n",
       " 2.0109711,\n",
       " 2.3015294,\n",
       " 2.1909297,\n",
       " 2.7841287,\n",
       " 3.3960891,\n",
       " 2.7723904,\n",
       " 2.6843922,\n",
       " 2.5696578,\n",
       " 1.873759,\n",
       " 4.0515051,\n",
       " 2.7884703,\n",
       " 3.1664705,\n",
       " 2.413384,\n",
       " 2.5950832,\n",
       " 1.4605098,\n",
       " 1.8590325,\n",
       " 3.5861123,\n",
       " 1.8300626,\n",
       " 3.8167338,\n",
       " 2.6522479,\n",
       " 4.1535492,\n",
       " 2.3777215,\n",
       " 1.7563944,\n",
       " 2.1654656,\n",
       " 2.0148871,\n",
       " 2.0356688,\n",
       " 1.4784111,\n",
       " 4.0939426]"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Per-example losses of one training mini-batch, split into a Python list.\n",
    "# num=128 is tied to the default batch_size of data_dict(); presumably the two\n",
    "# must be changed together.\n",
    "# NOTE(review): this needs W and b to be initialised in `sess`, but the\n",
    "# initialiser call sits in a later cell -- confirm the intended cell order.\n",
    "sess.run(tf.unpack(cross_entropies, num=128), feed_dict=data_dict(train=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(784, 784)"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def cosine_similarities(x):\n",
    "    \"\"\"Return the matrix of pairwise cosine similarities between rows of `x`.\n",
    "\n",
    "    `x` is expected to be a 2-D tensor holding one vector per row. Every row\n",
    "    is scaled to unit L2 norm on its own, so entry (i, j) of the result is\n",
    "    the cosine of the angle between row i and row j, and the diagonal is 1\n",
    "    for every non-zero row.\n",
    "    \"\"\"\n",
    "    # vec_normalize() divides by the norm of the whole tensor, which would only\n",
    "    # rescale the Gram matrix; the rows have to be normalised individually.\n",
    "    x = tf.nn.l2_normalize(x, 1)\n",
    "    return tf.matmul(x, tf.transpose(x))\n",
    "\n",
    "# flat_gW = tf.reshape(gW, [-1])\n",
    "# NOTE(review): tf.gradients sums the gradients over all elements of\n",
    "# cross_entropies, so gWs presumably has the shape of W, (784, 10), rather than\n",
    "# holding one gradient per example. That is consistent with the (784, 784)\n",
    "# shape this cell reports -- confirm this is the intended quantity.\n",
    "[gWs, gbs] = tf.gradients(cross_entropies, [W, b])\n",
    "sims = cosine_similarities(gWs)\n",
    "\n",
    "# sess.run(tf.initialize_all_variables())\n",
    "sess.run(sims, feed_dict=data_dict(train=False)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "16384"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch computation: 128 * 128 -- presumably the number of pairwise\n",
    "# similarities for a 128-example mini-batch (TODO confirm, or drop this cell).\n",
    "128 ** 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Merge every summary op registered while building the graph into one op.\n",
    "summaries = tf.merge_all_summaries()\n",
    "# Start from an empty log directory. ignore_errors=True keeps the very first\n",
    "# run, when the directory does not exist yet, from raising.\n",
    "shutil.rmtree(logs_dir, ignore_errors=True)\n",
    "# Separate event files for train and test; only the train writer is handed\n",
    "# the graph definition.\n",
    "train_writer = tf.train.SummaryWriter(logs_dir + '/train', sess.graph)\n",
    "test_writer = tf.train.SummaryWriter(logs_dir + '/test')\n",
    "\n",
    "# (Re-)initialise every variable of the graph before training.\n",
    "sess.run(tf.initialize_all_variables())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {
    "collapsed": false,
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy on test: 0.111, gdir similarity: 0.000, lr: 0.001000\n",
      "Accuracy on test: 0.138, gdir similarity: 0.960, lr: 0.001000\n",
      "Accuracy on test: 0.166, gdir similarity: 0.956, lr: 0.001000\n",
      "Accuracy on test: 0.216, gdir similarity: 0.938, lr: 0.001000\n",
      "Accuracy on test: 0.273, gdir similarity: 0.907, lr: 0.001000\n",
      "Accuracy on test: 0.329, gdir similarity: 0.945, lr: 0.001000\n",
      "Accuracy on test: 0.379, gdir similarity: 0.930, lr: 0.001000\n",
      "Accuracy on test: 0.421, gdir similarity: 0.949, lr: 0.001000\n",
      "Accuracy on test: 0.458, gdir similarity: 0.935, lr: 0.001000\n",
      "Accuracy on test: 0.491, gdir similarity: 0.948, lr: 0.001000\n",
      "Accuracy on test: 0.524, gdir similarity: 0.904, lr: 0.001000\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-94-0186f57f8d61>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     29\u001b[0m              fast_dir_update, dir_similarity],\n\u001b[0;32m     30\u001b[0m             feed_dict=data_dict(train=True))\n\u001b[1;32m---> 31\u001b[1;33m         \u001b[0mtrain_writer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_summaries\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mi\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     32\u001b[0m \u001b[1;31m#         if i - last_lr_change > cool_down:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     33\u001b[0m \u001b[1;31m#             if train_dir_similarity > 0.5:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/tensorflow/python/training/summary_io.py\u001b[0m in \u001b[0;36madd_summary\u001b[1;34m(self, summary, global_step)\u001b[0m\n\u001b[0;32m    133\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msummary\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbytes\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    134\u001b[0m       \u001b[0msumm\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msummary_pb2\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mSummary\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 135\u001b[1;33m       \u001b[0msumm\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mParseFromString\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0msummary\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    136\u001b[0m       \u001b[0msummary\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0msumm\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    137\u001b[0m     \u001b[0mevent\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mevent_pb2\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mEvent\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mwall_time\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0msummary\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msummary\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/message.py\u001b[0m in \u001b[0;36mParseFromString\u001b[1;34m(self, serialized)\u001b[0m\n\u001b[0;32m    183\u001b[0m     \"\"\"\n\u001b[0;32m    184\u001b[0m     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mClear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 185\u001b[1;33m     \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mMergeFromString\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserialized\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    186\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    187\u001b[0m   \u001b[1;32mdef\u001b[0m \u001b[0mSerializeToString\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py\u001b[0m in \u001b[0;36mMergeFromString\u001b[1;34m(self, serialized)\u001b[0m\n\u001b[0;32m   1089\u001b[0m     \u001b[0mlength\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserialized\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1090\u001b[0m     \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1091\u001b[1;33m       \u001b[1;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_InternalParse\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mserialized\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mlength\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mlength\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1092\u001b[0m         \u001b[1;31m# The only reason _InternalParse would return early is if it\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1093\u001b[0m         \u001b[1;31m# encountered an end-group tag.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py\u001b[0m in \u001b[0;36mInternalParse\u001b[1;34m(self, buffer, pos, end)\u001b[0m\n\u001b[0;32m   1125\u001b[0m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1126\u001b[0m       \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1127\u001b[1;33m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfield_decoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfield_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1128\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mfield_desc\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1129\u001b[0m           \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_UpdateOneofState\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfield_desc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py\u001b[0m in \u001b[0;36mDecodeRepeatedField\u001b[1;34m(buffer, pos, end, message, field_dict)\u001b[0m\n\u001b[0;32m    610\u001b[0m           \u001b[1;32mraise\u001b[0m \u001b[0m_DecodeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Truncated message.'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    611\u001b[0m         \u001b[1;31m# Read sub-message.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 612\u001b[1;33m         \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_InternalParse\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    613\u001b[0m           \u001b[1;31m# The only reason _InternalParse would return early is if it\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    614\u001b[0m           \u001b[1;31m# encountered an end-group tag.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py\u001b[0m in \u001b[0;36mInternalParse\u001b[1;34m(self, buffer, pos, end)\u001b[0m\n\u001b[0;32m   1125\u001b[0m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1126\u001b[0m       \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1127\u001b[1;33m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfield_decoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfield_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1128\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mfield_desc\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1129\u001b[0m           \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_UpdateOneofState\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfield_desc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py\u001b[0m in \u001b[0;36mDecodeField\u001b[1;34m(buffer, pos, end, message, field_dict)\u001b[0m\n\u001b[0;32m    631\u001b[0m         \u001b[1;32mraise\u001b[0m \u001b[0m_DecodeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Truncated message.'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    632\u001b[0m       \u001b[1;31m# Read sub-message.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 633\u001b[1;33m       \u001b[1;32mif\u001b[0m \u001b[0mvalue\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_InternalParse\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m!=\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    634\u001b[0m         \u001b[1;31m# The only reason _InternalParse would return early is if it encountered\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    635\u001b[0m         \u001b[1;31m# an end-group tag.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/python_message.py\u001b[0m in \u001b[0;36mInternalParse\u001b[1;34m(self, buffer, pos, end)\u001b[0m\n\u001b[0;32m   1125\u001b[0m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1126\u001b[0m       \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1127\u001b[1;33m         \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mfield_decoder\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnew_pos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfield_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   1128\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mfield_desc\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1129\u001b[0m           \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_UpdateOneofState\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfield_desc\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py\u001b[0m in \u001b[0;36mDecodeField\u001b[1;34m(buffer, pos, end, message, field_dict)\u001b[0m\n\u001b[0;32m    237\u001b[0m     \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    238\u001b[0m       \u001b[1;32mdef\u001b[0m \u001b[0mDecodeField\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmessage\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfield_dict\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 239\u001b[1;33m         \u001b[1;33m(\u001b[0m\u001b[0mfield_dict\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdecode_value\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbuffer\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mpos\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    240\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mpos\u001b[0m \u001b[1;33m>\u001b[0m \u001b[0mend\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    241\u001b[0m           \u001b[1;32mdel\u001b[0m \u001b[0mfield_dict\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[1;33m]\u001b[0m  \u001b[1;31m# Discard corrupt value.\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m/home/ogrisel/.virtualenvs/py35/lib/python3.5/site-packages/google/protobuf/internal/decoder.py\u001b[0m in \u001b[0;36mInnerDecode\u001b[1;34m(buffer, pos)\u001b[0m\n\u001b[0;32m    338\u001b[0m     \u001b[1;31m# bit set, it's not a number.  In Python 2.4, struct.unpack will treat it\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    339\u001b[0m     \u001b[1;31m# as inf or -inf.  To avoid that, we treat it specially.\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 340\u001b[1;33m     if ((double_bytes[7:8] in b'\\x7F\\xFF')\n\u001b[0m\u001b[0;32m    341\u001b[0m         \u001b[1;32mand\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mdouble_bytes\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m6\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;36m7\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m>=\u001b[0m \u001b[1;34mb'\\xF0'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    342\u001b[0m         and (double_bytes[0:7] != b'\\x00\\x00\\x00\\x00\\x00\\x00\\xF0')):\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Book-keeping for the learning-rate schedule; both values are only used by\n",
    "# the commented-out block at the end of this cell.\n",
    "last_lr_change = 0\n",
    "cool_down = 100\n",
    "\n",
    "for i in range(10000):\n",
    "    if i % 100 == 0:\n",
    "        # Every 100th iteration: evaluate on the full test set. No update op is\n",
    "        # fetched here, so the parameters are left untouched.\n",
    "        test_summaries, test_acc, test_dir_similarity, lr = sess.run(\n",
    "            [summaries, accuracy, dir_similarity, learning_rate],\n",
    "            feed_dict=data_dict(train=False))\n",
    "        test_writer.add_summary(test_summaries, i)\n",
    "        print(\"Accuracy on test: %0.3f, gdir similarity: %0.3f, lr: %f\"\n",
    "              % (test_acc, test_dir_similarity, lr))\n",
    "        # Stop once the learning rate has become negligible. As far as this\n",
    "        # notebook shows, only the disabled schedule below ever lowers it.\n",
    "        if lr < 1e-10:\n",
    "            print('Converged!')\n",
    "            break\n",
    "    else:\n",
    "        # One gradient step on a training mini-batch. The same run also updates\n",
    "        # both direction vectors and fetches their similarity, which is only\n",
    "        # consumed by the disabled schedule below.\n",
    "        train_summaries, _, _, _, _, train_dir_similarity = sess.run(\n",
    "            [summaries, W_update, b_update, slow_dir_update,\n",
    "             fast_dir_update, dir_similarity],\n",
    "            feed_dict=data_dict(train=True))\n",
    "        train_writer.add_summary(train_summaries, i)\n",
    "# Disabled learning-rate schedule, kept for reference: once the cool-down has\n",
    "# elapsed, double the rate when the two directions agree (similarity > 0.5),\n",
    "# cut it to a tenth when they oppose each other (similarity < 0), and reset\n",
    "# both direction vectors in either case.\n",
    "#         if i - last_lr_change > cool_down:\n",
    "#             if train_dir_similarity > 0.5:\n",
    "#                 print(\"up\")\n",
    "#                 sess.run([lr_up, slow_dir_reset, fast_dir_reset],\n",
    "#                          feed_dict=data_dict(train=True))\n",
    "#                 last_lr_change = i\n",
    "#                 cool_down = 1000\n",
    "#             elif train_dir_similarity < 0:\n",
    "#                 print('down')\n",
    "#                 sess.run([lr_down, slow_dir_reset, fast_dir_reset],\n",
    "#                          feed_dict=data_dict(train=True))\n",
    "#                 last_lr_change = i\n",
    "#                 cool_down = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
